{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ML Simulation Training\n",
    "In this notebook you are going to train an ML policy.\n",
    "\n",
    "However, you won't use examples from the SDV as data, but from the other agents around it instead.\n",
    "\n",
    "This may sound like a small difference, but it has profound implications:\n",
    "- by using data from multiple sources you're **including much more variability in your training data**;\n",
    "- two agents may have made different choices at the same intersection, leading to **multi-modal data**;\n",
    "- the **quality of the annotated data is expected to be noticeably lower** compared to the SDV, as the agent annotations come from a perception system.\n",
    "\n",
    "Still, the final prize is even better than in planning: this policy can potentially drive all the agents in the scene, not just the SDV.\n",
    "\n",
    "![simulation-example](https://github.com/woven-planet/l5kit/blob/master/docs/images/simulation/simulation_example.svg?raw=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tempfile import gettempdir\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.data import LocalDataManager, ChunkedDataset\n",
    "from l5kit.dataset import AgentDataset\n",
    "from l5kit.rasterization import build_rasterizer\n",
    "from l5kit.geometry import transform_points\n",
    "from l5kit.visualization import TARGET_POINTS_COLOR, draw_trajectory\n",
    "from l5kit.planning.rasterized.model import RasterizedPlanningModel\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare data path and load cfg\n",
    "\n",
    "By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.\n",
    "\n",
    "Then, we load our config file with relative paths and other configurations (rasteriser, training params...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Download L5 Sample Dataset and install L5Kit\n",
    "import os\n",
    "RunningInCOLAB = 'google.colab' in str(get_ipython())\n",
    "if RunningInCOLAB:\n",
    "    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q\n",
    "    !sh ./setup_notebook_colab.sh\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = open(\"./dataset_dir.txt\", \"r\").read().strip()\n",
    "else:\n",
    "    print(\"Not running in Google Colab.\")\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = \"/tmp/l5kit_data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dm = LocalDataManager(None)\n",
    "# get config\n",
    "cfg = load_config_data(\"./config.yaml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rasterisation\n",
    "rasterizer = build_rasterizer(cfg, dm)\n",
    "\n",
    "# ===== INIT DATASET\n",
    "train_zarr = ChunkedDataset(dm.require(cfg[\"train_data_loader\"][\"key\"])).open()\n",
    "train_dataset = AgentDataset(cfg, train_zarr, rasterizer)\n",
    "\n",
    "# plot some examples (about 10, evenly spaced; the step is at least 1 so small datasets also work)\n",
    "for idx in range(0, len(train_dataset), max(1, len(train_dataset) // 10)):\n",
    "    data = train_dataset[idx]\n",
    "    im = rasterizer.to_rgb(data[\"image\"].transpose(1, 2, 0))\n",
    "    target_positions = transform_points(data[\"target_positions\"], data[\"raster_from_agent\"])\n",
    "    draw_trajectory(im, target_positions, TARGET_POINTS_COLOR)\n",
    "    plt.imshow(im)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = RasterizedPlanningModel(\n",
    "        model_arch=cfg[\"model_params\"][\"model_architecture\"],\n",
    "        num_input_channels=rasterizer.num_channels(),\n",
    "        num_targets=3 * cfg[\"model_params\"][\"future_num_frames\"],  # (X, Y, yaw) for each future state\n",
    "        weights_scaling=[1., 1., 1.],\n",
    "        criterion=nn.MSELoss(reduction=\"none\")\n",
    "        )\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare for training\n",
    "Our `AgentDataset` inherits from PyTorch `Dataset`, so we can wrap it in a `DataLoader` to load batches with multiple worker processes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cfg = cfg[\"train_data_loader\"]\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=train_cfg[\"shuffle\"], batch_size=train_cfg[\"batch_size\"], \n",
    "                             num_workers=train_cfg[\"num_workers\"])\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "print(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training loop\n",
    "Here, we purposely include a bare-bones training loop. Many more components can be added to enrich logging and improve performance, to name a few:\n",
    "- learning rate drop;\n",
    "- loss weights tuning;\n",
    "- importance sampling.\n",
    "\n",
    "Still, the sheer size of our dataset means that reasonable performance can be obtained even with this simple loop."
   ]
  },
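  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of the first item in the list above, a learning rate drop could be sketched with PyTorch's `StepLR` scheduler. This is a minimal, self-contained example on a toy linear model; `toy_model`, `step_size` and `gamma` are assumptions chosen for illustration, not tuned settings from this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "toy_model = nn.Linear(4, 2)  # hypothetical stand-in for the real model\n",
    "optimizer = optim.Adam(toy_model.parameters(), lr=1e-3)\n",
    "# multiply the learning rate by gamma every step_size scheduler steps (assumed values)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)\n",
    "\n",
    "for step in range(4):\n",
    "    loss = toy_model(torch.randn(8, 4)).pow(2).mean()\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()  # call after optimizer.step()\n",
    "    print(step, optimizer.param_groups[0][\"lr\"])\n",
    "```\n",
    "\n",
    "In a real training loop, the `scheduler.step()` call would typically go right after `optimizer.step()`."
   ]
  },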
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_it = iter(train_dataloader)\n",
    "progress_bar = tqdm(range(cfg[\"train_params\"][\"max_num_steps\"]))\n",
    "losses_train = []\n",
    "model.train()\n",
    "torch.set_grad_enabled(True)\n",
    "\n",
    "for _ in progress_bar:\n",
    "    try:\n",
    "        data = next(tr_it)\n",
    "    except StopIteration:\n",
    "        tr_it = iter(train_dataloader)\n",
    "        data = next(tr_it)\n",
    "    # Forward pass\n",
    "    data = {k: v.to(device) for k, v in data.items()}\n",
    "    result = model(data)\n",
    "    loss = result[\"loss\"]\n",
    "    # Backward pass\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    losses_train.append(loss.item())\n",
    "    progress_bar.set_description(f\"loss: {loss.item()} loss(avg): {np.mean(losses_train)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the train loss curve\n",
    "We can plot the train loss against the iterations (batch-wise) to check if our model has converged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.arange(len(losses_train)), losses_train, label=\"train loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Store the model\n",
    "\n",
    "Let's store the model as TorchScript. This format allows us to reload the model and its weights later without requiring the class definition.\n",
    "\n",
    "**Take note of the path: you will use it later to evaluate your simulation model!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "to_save = torch.jit.script(model.cpu())\n",
    "path_to_save = f\"{gettempdir()}/simulation_model.pt\"\n",
    "to_save.save(path_to_save)\n",
    "print(f\"MODEL STORED at {path_to_save}\")"
   ]
  },
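  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, reloading a TorchScript file might look like the sketch below. To stay self-contained it scripts and saves a toy `nn.Linear` module first; `toy_path` is a hypothetical file name, and the stored simulation model would be loaded the same way from its own path.\n",
    "\n",
    "```python\n",
    "from tempfile import gettempdir\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "toy_path = f\"{gettempdir()}/toy_scripted.pt\"  # hypothetical file name\n",
    "torch.jit.script(nn.Linear(4, 2)).save(toy_path)\n",
    "\n",
    "# no class definition is needed to load the scripted module\n",
    "reloaded = torch.jit.load(toy_path, map_location=\"cpu\")\n",
    "reloaded.eval()  # switch to evaluation mode before inference\n",
    "print(reloaded(torch.zeros(1, 4)).shape)\n",
    "```\n",
    "\n",
    "Passing `map_location` lets a model saved on one device be loaded on another, for example a CPU-only machine."
   ]
  },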
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations on training your first ML policy for simulation!\n",
    "### What's Next\n",
    "\n",
    "Now that your model is trained and safely stored, you can use it to control the agents around ego. We have a notebook just for that.\n",
    "\n",
    "### [Simulation evaluation](./simulation_test.ipynb)\n",
    "In that notebook a `planning_model` will control the SDV, while the `simulation_model` you just trained will be used for all other agents.\n",
    "\n",
    "Don't worry if you don't have the resources required to train a model: we provide pre-trained models just below.\n",
    "\n",
    "## Pre-trained models\n",
    "We provide a collection of pre-trained models for the simulation task:\n",
    "- [simulation model](https://d20lyvjneielsk.cloudfront.net/simulation_model_20210416_5steps.pt) trained on agents over the semantic rasteriser with a history of 0.5s;\n",
    "- [planning model](https://d20lyvjneielsk.cloudfront.net/planning_model_20210421_5steps.pt) trained on the AV over the semantic rasteriser with a history of 0.5s.\n",
    "\n",
    "To use one of the models, simply download the corresponding `.pt` file and load it in the evaluation notebooks."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
